#!/usr/bin/env python3

import datetime
import math
import operator
import os
import sys
import time

import isodate
from mpi4py import MPI
import numpy as np
import pandas as pd
import xlrd

sys.path.insert(0, os.path.dirname(__file__) + '/../../lib')
import forecast
import math_
import testable
import timeseries
import u


### Setup ###

# Short aliases for objects provided by the project's u module. l is used as
# the logger by debug() and info() below.
l = u.l
# NOTE(review): c is not referenced in the visible code; presumably shared
# configuration -- confirm before removing.
c = u.c

# Integer with the low 32 bits set; shufflereduce() shifts and masks it to
# build its per-round slot mask.
ALLSET_32 = 0xffffffff

# Module-level state. args is assigned by the bootstrap code at the bottom of
# the file; truth and tests are assigned in main().
args = None
truth = None
tests = None

# MPI handles: the world communicator, the number of ranks in it, this
# process's rank, and the processor name reported by MPI.
comm = MPI.COMM_WORLD
rank_ct = comm.Get_size()
rank = comm.Get_rank()
name = MPI.Get_processor_name()

# Command-line interface. Three positional arguments (TSPATH, TRUTHXLS,
# OUTDIR); --freq, --horizon, --teststride, and --training are required, the
# remaining options have defaults.
ap = u.ArgumentParser()
gr = ap.default_group
gr.add_argument('tspath',
                metavar='TSPATH',
                help='path to time series dataset')
gr.add_argument('truth',
                metavar='TRUTHXLS',
                help='path to ground truth Excel file')
gr.add_argument('outdir',
                metavar='OUTDIR',
                help='directory for output')
gr.add_argument('--candidates',
                metavar='N',
                type=int,
                default=100,
                help='number of candidate articles for regression')
gr.add_argument('--freq',
                required=True,
                metavar='F',
                help='use truth with frequency F')
gr.add_argument('--horizon',
                nargs='+',
                required=True,
                metavar='N',
                type=int,
                help='forecast horizons to test (number of intervals)')
# sys.maxsize as the default means "effectively no limit".
gr.add_argument('--limit',
                metavar='N',
                type=int,
                default=sys.maxsize,
                help='stop after this many time series per rank')
gr.add_argument('--profile',
                action='store_true',
                help='each rank drops a profile file')
gr.add_argument('--teststride',
                required=True,
                type=int,
                metavar='N',
                help='number of intervals between model builds')
gr.add_argument('--training',
                nargs='+',
                required=True,
                metavar='N',
                type=int,
                help='training periods to test (number of intervals)')
# sys.maxsize as the default means "use every shard"; candidates_read() takes
# the minimum of this and the dataset's shard count.
gr.add_argument('--shards',
                metavar='N',
                type=int,
                default=sys.maxsize,
                help='use this many shards, instead of all in dataset')



### Main ###

def main():
   '''Run the experiment on this rank: load and broadcast ground truth,
      enumerate the tests, and read candidate time series. If --profile was
      given, each rank writes OUTDIR/<rank>.prof before returning.'''
   global truth, tests
   is_root = (rank == 0)

   info('starting')
   if (is_root):
      u.mkdir_f(args.outdir)
   if (args.profile):
      prof = u.Profiler()

   # Ground truth: read the file on rank 0 only, to avoid excess I/O, then
   # broadcast it so every rank holds the same data.
   if (is_root):
      truth = truth_load()
      info('found truth with %d outbreaks' % len(truth.columns))
   truth = comm.bcast(truth)

   # Work out which tests to run.
   tests = tests_enumerate()
   info('planning %d tests' % len(tests))

   # 1. Find candidate time series
   #
   # 1a. Find top candidates within each shard. The result has one entry per
   #     shard read by this rank; each entry is --
   #
   #     list: (Test,
   #            Priority_Queue:
   #               pri: absolute correlation with ground truth
   #               val: Series  shifted/truncated training data, .name is URL)
   shard_cands = candidates_read()

   # 1b. Find global top candidates; output is same as above. (Not yet
   #     implemented: shard_cands is not consumed further.)

   # Wrap up.
   if (args.profile):
      prof.stop('%s/%d.prof' % (args.outdir, rank))
   info('done')


### Classes ###


### Core functions ###

def candidates_read():
   '''Open the time series dataset at args.tspath, read this rank's share of
      its shards, and return a list with one shard_read() result per shard
      read.

      Shards are dealt round-robin: this rank reads shards rank, rank +
      rank_ct, rank + 2*rank_ct, ..., below min(ds.hashmod, args.shards), so
      that each shard is read by exactly one rank. The dataset is closed even
      if reading a shard raises.'''
   ds = timeseries.Dataset_Pandas(args.tspath)
   ds.open_all()
   try:
      info('opened data set with %d fragments, %d shards, %d shard limit'
           % (len(ds.fragment_tags), ds.hashmod, args.shards))
      shard_ct = min(ds.hashmod, args.shards)
      # Start at this rank's index rather than 0; starting every rank at 0
      # would make all ranks read the same shards and skip the rest.
      return [shard_read(ds, shard_i)
              for shard_i in range(rank, shard_ct, rank_ct)]
   finally:
      ds.close()

def shard_read(ds, shard_i):
   '''Evaluate the time series in shard shard_i of dataset ds against every
      test context and return a list of (context, u.Priority_Queue) pairs, one
      per item in the global tests, in the same order.

      Each series fetched from the shard is converted to training data for
      each context and added to that context's queue with priority
      abs(ctx.corr(training data)). Queues are created with capacity
      args.candidates. At most args.limit series are evaluated. A summary
      line is logged unless nothing was evaluated.'''
   info('processing shard %d' % shard_i)
   start = time.time()
   cands = [(ctx, u.Priority_Queue(args.candidates)) for ctx in tests]
   train_hits = None
   # Count evaluated series explicitly. A loop index would be unbound if the
   # shard yields nothing and one too high when the limit ends the loop.
   article_ct = 0
   for hits in ds.fetch_all(shard_i, last_only=False,
                            normalize=True, resample=args.freq):
      if (article_ct >= args.limit):
         break
      for (ctx, pq) in cands:
         # Shift and trim. The previous result is passed back in; presumably
         # training_data() can reuse it -- TODO confirm.
         train_hits = ctx.training_data(hits, train_hits)
         pq.add(abs(ctx.corr(train_hits)), train_hits)
      article_ct += 1
   elapsed = time.time() - start
   eval_ct = article_ct * len(cands)
   if (eval_ct != 0):  # also guards the division below
      info('evaluated: %d articles, %d contexts, %d total; %s (%.0f µs/eval)'
           % (article_ct, len(cands), eval_ct, u.fmt_seconds(elapsed),
              elapsed * 1e6 / eval_ct))
   return cands


### MPI functions ###

def shufflereduce(op, data):
   '''This takes parallel lists of items, one list per rank, reduces them
      using the user-provided operator, and distributes the results evenly
      among ranks. For example, you might have 6-item distributed lists of
      strings across 4 ranks:

        0  ['a', 'b', 'c', 'd', 'e', 'f']
        1  ['a', 'b', 'c', 'd', 'e', 'f']
        2  ['a', 'b', 'c', 'd', 'e', 'f']
        3  ['a', 'b', 'c', 'd', 'e', 'f']

      shufflereduce(operator.add) will return the following on each rank. The
      results are guaranteed to be evenly distributed. You can have fewer list
      items than ranks.

        0  ['aaaa', 'eeee']
        1  ['bbbb', 'ffff']
        2  ['cccc']
        3  ['dddd']

      Arguments:

        op    two-argument callable used to combine items; it is invoked via
              merge(), so it never sees None.
        data  this rank's list of items; emptied by this call.

      Returns a new list holding the reduced items assigned to this rank.

      Caveats:

        * List position defines what is to be shuffled. There is no key/value.
        * Rank count must be a power of two (enforced with an assertion);
          other rank counts are untested.
        * List length and item order are not preserved.
        * The shuffle happens in-place, so the passed list will be destroyed.

      More precise example:

        >>> ct = 14
        >>> start = [chr(ord('a')+i) for i in range(ct)]
        >>> start
        ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n']
        >>> end = [chr(ord('a')+i) * rank_ct for i in range(rank, ct, rank_ct)]
        >>> actual = shufflereduce(operator.add, start)
        >>> actual == end
        True'''
   assert math_.is_power_2(rank_ct)
   # Pad with Nones to a multiple of rank_ct so every slot index computed
   # below exists in the list.
   list_pad(data, rank_ct)
   # One exchange round per bit of the rank number.
   for i in range(int(math.log2(rank_ct))):
      buddy = rank ^ (1 << i)  # flip bit i
      # Slots are handled in groups of 2**(i+1); destmask keeps the low i+1
      # bits of a rank number, i.e. an offset within such a group. (+ binds
      # tighter than <<, so the shift amount is i+1.)
      j_inc = 2**(i+1)
      destmask = ALLSET_32 & ~(ALLSET_32 << i+1)
      for j in range(0, len(data), j_inc):
         # Within each group, this rank keeps the slot at its own offset and
         # hands the slot at the buddy's offset over to the buddy.
         k_in =  j + (rank  & destmask)
         k_out = j + (buddy & destmask)
         in_ = comm.sendrecv(data[k_out], dest=buddy, source=buddy)
         data[k_out] = None
         # The received item is the left operand of op.
         data[k_in] = merge(op, in_, data[k_in])
   # Only the slots this rank kept are non-None now (padding and handed-over
   # slots are None).
   result = [i for i in data if i is not None]
   del data[:]  # finish clearing list so we don't keep references in it
   return result

def merge(op, a, b):
   """Combine a and b with op, treating None as "nothing to merge": if exactly
      one argument is None the other is returned unchanged, and if both are
      None the result is None. op is therefore only ever called with two real
      items, as op(a, b).

      >>> merge(operator.add, 1, 2)
      3
      >>> merge(operator.add, 1, None)
      1
      >>> merge(operator.add, None, 2)
      2
      >>> merge(operator.add, None, None)
      >>> merge(operator.add, 1, [])
      Traceback (most recent call last):
        ...
      TypeError: unsupported operand type(s) for +: 'int' and 'list'"""
   if (a is not None and b is not None):
      return op(a, b)
   # At least one side is missing: pass the other through (possibly None).
   return b if (a is None) else a

def list_pad(l, m):
   '''Append Nones to list l, in place, until its length is a multiple of m
      (nothing is appended if it already is), then return l itself. Raises
      ValueError if m is not positive.

      >>> list_pad([], 3)
      []
      >>> list_pad([1], 3)
      [1, None, None]
      >>> list_pad([1, 2], 3)
      [1, 2, None]
      >>> list_pad([1, 2, 3], 3)
      [1, 2, 3]
      >>> list_pad([1, 2, 3, 4], 3)
      [1, 2, 3, 4, None, None]
      >>> list_pad([], 0)
      Traceback (most recent call last):
        ...
      ValueError: can't pad list to length 0 <= 0'''
   if (m <= 0):
      raise ValueError("can't pad list to length %s <= 0" % m)
   # Distance from the current length up to the next multiple of m.
   shortfall = (-len(l)) % m
   l.extend([None] * shortfall)
   return l


### Supporting functions ###

def debug(msg):
   'Log msg at debug level, prefixed with "<rank>/<rank count>: ".'
   tagged = '%d/%d: %s' % (rank, rank_ct, msg)
   l.debug(tagged)

def info(msg):
   '''Log msg prefixed with "<rank>/<rank count>: ". Rank 0 logs at info
      level; every other rank logs at debug level.'''
   if (rank == 0):
      emit = l.info
   else:
      emit = l.debug
   emit('%d/%d: %s' % (rank, rank_ct, msg))

def tests_enumerate():
   '''Return a list of forecast.Context objects, one for every combination of
      truth column, training period (args.training), horizon (args.horizon),
      and "now" index produced by forecast.nows().'''
   contexts = []
   for (outbreak, series) in truth.items():
      period_ct = len(series)
      for training in args.training:
         for horizon in args.horizon:
            # The idea is to cover every possible "now" that leaves enough
            # time before it for the training period and enough time after
            # it for at least one test.
            contexts.extend(
               forecast.Context(truth, outbreak, training, horizon, now)
               for now in forecast.nows(period_ct, training, horizon,
                                        args.teststride))
   return contexts

def to_date(book, x):
   '''Convert Excel date value x to a datetime.datetime, interpreting it with
      the date mode of workbook book.'''
   date_parts = xlrd.xldate_as_tuple(x, book.datemode)
   return datetime.datetime(*date_parts)

def truth_load():
   '''Load ground truth from the Excel workbook args.truth and return it as a
      DataFrame with one column per header in the sheet's first row (after the
      date column), indexed by a PeriodIndex of frequency args.freq.

      The first sheet whose name equals args.freq is used. Its first column
      holds dates and its first row holds column names; empty cells become
      NaN. Side effects: sets args.ts_start and args.ts_end, logs a short
      summary, and writes the DataFrame's text form to
      OUTDIR/dataframe.txt.

      Raises ValueError if no sheet is named args.freq or the sheet has no
      data rows.'''
   # We could use Pandas' Excel functions instead, to save a dependency, but
   # this seems perhaps clearer? I may be wrong.
   book = xlrd.open_workbook(args.truth)
   sheet = next((s for s in book.sheets() if s.name == args.freq), None)
   if (sheet is None):
      # Fail with a clear message instead of an unbound-variable error later.
      raise ValueError('no sheet named %s in truth file %s'
                       % (args.freq, args.truth))
   headers = sheet.row_values(0)
   dates = sheet.col_values(0, start_rowx=1)
   if (not dates):
      raise ValueError('sheet %s in truth file %s has no data rows'
                       % (args.freq, args.truth))
   date_start = to_date(book, dates[0])
   date_end = to_date(book, dates[-1])
   index = pd.period_range(date_start, periods=len(dates), freq=args.freq)
   # Sanity check: the generated index must begin and end on the same dates
   # as the sheet's date column.
   assert (    date_start == index[0].to_timestamp(how='start')
           and date_end == index[-1].to_timestamp(how='start')), \
          'dates in truth sheet do not match frequency %s' % args.freq
   columns = dict()
   for (i, context) in enumerate(headers[1:], start=1):
      columns[context] = pd.Series([(j if j != '' else np.nan)
                                    for j in sheet.col_values(i, start_rowx=1)],
                                   index)
   args.ts_start = index[0].to_timestamp(how='start')
   # NOTE(review): the +1 day, -1 microsecond adjustment looks like it assumes
   # to_timestamp(how='end') returns the start of the period's last day;
   # confirm against the pandas version in use.
   args.ts_end = (index[-1].to_timestamp(how='end')
                  + datetime.timedelta(days=1)
                  - datetime.timedelta(microseconds=1))
   info('  periods:      %d' % len(index))
   info('  start:        %s (%s)' % (args.ts_start,
                                     args.ts_start.strftime('%A')))
   info('  end:          %s (%s)' % (args.ts_end,
                                     args.ts_end.strftime('%A')))
   df = pd.DataFrame(columns, index=index)
   # Use a context manager so the file is flushed and closed promptly.
   with open('%s/dataframe.txt' % args.outdir, 'w') as fp:
      print(df, file=fp)
   return df


### Bootstrap ###

# When run as a script: parse the command line into the global args,
# configure from args.config, initialize logging, and run main().
# NOTE(review): testable.Unittests_Only_Exception is presumably raised during
# this start-up when only the unit tests were requested -- confirm in
# testable/u. In that case the module registers its tests instead of running.
try:
   if (__name__ == '__main__'):
      args = u.parse_args(ap)
      u.configure(args.config)
      u.logging_init('expmt')
      main()
except testable.Unittests_Only_Exception:
   testable.register()
